/* JAI-Ext - OpenSource Java Advanced Image Extensions Library
*    http://www.geo-solutions.it/
*    Copyright 2014 GeoSolutions


* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.eclipse.imagen.media.contrastenhancement;

import java.awt.Rectangle;
import java.awt.image.DataBuffer;
import java.awt.image.Raster;
import java.awt.image.RenderedImage;
import java.awt.image.WritableRaster;
import java.util.Map;
import org.eclipse.imagen.ImageLayout;
import org.eclipse.imagen.PointOpImage;
import org.eclipse.imagen.RasterAccessor;
import org.eclipse.imagen.RasterFormatTag;
import org.eclipse.imagen.media.util.ImageUtil;

/**
 * An <code>OpImage</code> implementing a square root stretch operation.
 *
 * <p>This <code>OpImage</code> does this operation on the pixels values:
 *
 * <p>dest = SQRT(((src/(inMax - inMin)) * outMax) + outMin)
 */
final class SquareRootStretchOpImage extends PointOpImage {

    /** The inputMin value */
    protected int[] inputMin;

    /** The inputMax value */
    protected int[] inputMax;

    /** The outputMin value */
    protected int[] outputMin;

    /** The outputMax value */
    protected int[] outputMax;

    /**
     * Constructor.
     *
     * @param source The source image.
     * @param layout The destination image layout.
     * @param inputMin
     */
    public SquareRootStretchOpImage(
            RenderedImage source,
            Map config,
            ImageLayout layout,
            int[] inputMin,
            int[] inputMax,
            int[] outputMin,
            int[] outputMax) {
        super(source, layout, config, true);

        int numBands = getSampleModel().getNumBands();

        if (inputMin.length < numBands) {
            this.inputMin = new int[numBands];
            for (int i = 0; i < numBands; i++) {
                this.inputMin[i] = inputMin[0];
            }
        } else {
            this.inputMin = (int[]) inputMin.clone();
        }

        if (inputMax.length < numBands) {
            this.inputMax = new int[numBands];
            for (int i = 0; i < numBands; i++) {
                this.inputMax[i] = inputMax[0];
            }
        } else {
            this.inputMax = (int[]) inputMax.clone();
        }

        if (outputMin.length < numBands) {
            this.outputMin = new int[numBands];
            for (int i = 0; i < numBands; i++) {
                this.outputMin[i] = outputMin[0];
            }
        } else {
            this.outputMin = (int[]) outputMin.clone();
        }

        if (outputMax.length < numBands) {
            this.outputMax = new int[numBands];
            for (int i = 0; i < numBands; i++) {
                this.outputMax[i] = outputMax[0];
            }
        } else {
            this.outputMax = (int[]) outputMax.clone();
        }

        // Set flag to permit in-place operation.
        permitInPlaceOperation();
    }

    /**
     * Compute the square root stretching
     *
     * @param sources Cobbled sources, guaranteed to provide all the source data necessary for computing the rectangle.
     * @param dest The tile containing the rectangle to be computed.
     * @param destRect The rectangle within the tile to be computed.
     */
    protected void computeRect(Raster[] sources, WritableRaster dest, Rectangle destRect) {
        // Retrieve format tags.
        RasterFormatTag[] formatTags = getFormatTags();

        Rectangle srcRect = mapDestRect(destRect, 0);

        RasterAccessor dst = new RasterAccessor(dest, destRect, formatTags[1], getColorModel());
        RasterAccessor src = new RasterAccessor(
                sources[0], srcRect, formatTags[0], getSource(0).getColorModel());

        switch (dst.getDataType()) {
            case DataBuffer.TYPE_BYTE:
                computeRectByte(src, dst);
                break;
            case DataBuffer.TYPE_USHORT:
                computeRectUShort(src, dst);
                break;
            case DataBuffer.TYPE_SHORT:
                computeRectShort(src, dst);
                break;
            case DataBuffer.TYPE_INT:
                computeRectInt(src, dst);
                break;
                //        case DataBuffer.TYPE_FLOAT:
                //            computeRectFloat(src, dst);
                //            break;
                //        case DataBuffer.TYPE_DOUBLE:
                //            computeRectDouble(src, dst);
                //            break;
        }

        if (dst.needsClamping()) {

            /* Further clamp down to underlying raster data type. */
            dst.clampDataArrays();
        }
        dst.copyDataToRaster();
    }

    private void computeRectByte(RasterAccessor src, RasterAccessor dst) {
        int dstWidth = dst.getWidth();
        int dstHeight = dst.getHeight();
        int dstBands = dst.getNumBands();

        int dstLineStride = dst.getScanlineStride();
        int dstPixelStride = dst.getPixelStride();
        int[] dstBandOffsets = dst.getBandOffsets();
        byte[][] dstData = dst.getByteDataArrays();

        int srcLineStride = src.getScanlineStride();
        int srcPixelStride = src.getPixelStride();
        int[] srcBandOffsets = src.getBandOffsets();
        byte[][] srcData = src.getByteDataArrays();

        for (int b = 0; b < dstBands; b++) {
            float c = (float) inputMin[b];
            byte[] d = dstData[b];
            byte[] s = srcData[b];

            int dstLineOffset = dstBandOffsets[b];
            int srcLineOffset = srcBandOffsets[b];

            for (int h = 0; h < dstHeight; h++) {
                int dstPixelOffset = dstLineOffset;
                int srcPixelOffset = srcLineOffset;

                dstLineOffset += dstLineStride;
                srcLineOffset += srcLineStride;

                for (int w = 0; w < dstWidth; w++) {
                    d[dstPixelOffset] = ImageUtil.clampRoundByte((s[srcPixelOffset] & 0xFF) * c);

                    dstPixelOffset += dstPixelStride;
                    srcPixelOffset += srcPixelStride;
                }
            }
        }
    }

    private void computeRectUShort(RasterAccessor src, RasterAccessor dst) {
        int dstWidth = dst.getWidth();
        int dstHeight = dst.getHeight();
        int dstBands = dst.getNumBands();

        int dstLineStride = dst.getScanlineStride();
        int dstPixelStride = dst.getPixelStride();
        int[] dstBandOffsets = dst.getBandOffsets();
        short[][] dstData = dst.getShortDataArrays();

        int srcLineStride = src.getScanlineStride();
        int srcPixelStride = src.getPixelStride();
        int[] srcBandOffsets = src.getBandOffsets();
        short[][] srcData = src.getShortDataArrays();

        int val = 0;
        for (int b = 0; b < dstBands; b++) {
            short[] d = dstData[b];
            short[] s = srcData[b];

            int inMin = inputMin[b];
            int inMax = inputMax[b];
            int outMin = outputMin[b];
            int outMax = outputMax[b];

            int dstLineOffset = dstBandOffsets[b];
            int srcLineOffset = srcBandOffsets[b];

            for (int h = 0; h < dstHeight; h++) {
                int dstPixelOffset = dstLineOffset;
                int srcPixelOffset = srcLineOffset;

                dstLineOffset += dstLineStride;
                srcLineOffset += srcLineStride;

                for (int w = 0; w < dstWidth; w++) {
                    val = ((s[srcPixelOffset] >= inMin) ? (s[srcPixelOffset] - inMin) : 0);
                    d[dstPixelOffset] =
                            ImageUtil.clampRoundUShort((Math.sqrt((val / (inMax - inMin))) * outMax) + outMin);
                    dstPixelOffset += dstPixelStride;
                    srcPixelOffset += srcPixelStride;
                }
            }
        }
    }

    private void computeRectShort(RasterAccessor src, RasterAccessor dst) {
        int dstWidth = dst.getWidth();
        int dstHeight = dst.getHeight();
        int dstBands = dst.getNumBands();

        int dstLineStride = dst.getScanlineStride();
        int dstPixelStride = dst.getPixelStride();
        int[] dstBandOffsets = dst.getBandOffsets();
        short[][] dstData = dst.getShortDataArrays();

        int srcLineStride = src.getScanlineStride();
        int srcPixelStride = src.getPixelStride();
        int[] srcBandOffsets = src.getBandOffsets();
        short[][] srcData = src.getShortDataArrays();
        double val = 0;
        for (int b = 0; b < dstBands; b++) {
            int inMin = inputMin[b];
            int inMax = inputMax[b];
            int outMin = outputMin[b];
            int outMax = outputMax[b];

            short[] d = dstData[b];
            short[] s = srcData[b];

            int dstLineOffset = dstBandOffsets[b];
            int srcLineOffset = srcBandOffsets[b];

            for (int h = 0; h < dstHeight; h++) {
                int dstPixelOffset = dstLineOffset;
                int srcPixelOffset = srcLineOffset;

                dstLineOffset += dstLineStride;
                srcLineOffset += srcLineStride;

                for (int w = 0; w < dstWidth; w++) {
                    val = ((s[srcPixelOffset] >= inMin) ? (s[srcPixelOffset] - inMin) : 0);

                    d[dstPixelOffset] =
                            ImageUtil.clampRoundShort((Math.sqrt((val / (inMax - inMin))) * outMax) + outMin);

                    dstPixelOffset += dstPixelStride;
                    srcPixelOffset += srcPixelStride;
                }
            }
        }
    }

    private void computeRectInt(RasterAccessor src, RasterAccessor dst) {
        int dstWidth = dst.getWidth();
        int dstHeight = dst.getHeight();
        int dstBands = dst.getNumBands();

        int dstLineStride = dst.getScanlineStride();
        int dstPixelStride = dst.getPixelStride();
        int[] dstBandOffsets = dst.getBandOffsets();
        int[][] dstData = dst.getIntDataArrays();

        int srcLineStride = src.getScanlineStride();
        int srcPixelStride = src.getPixelStride();
        int[] srcBandOffsets = src.getBandOffsets();
        int[][] srcData = src.getIntDataArrays();

        double val = 0;
        for (int b = 0; b < dstBands; b++) {
            int inMin = inputMin[b];
            int inMax = inputMax[b];
            int outMin = outputMin[b];
            int outMax = outputMax[b];

            int[] d = dstData[b];
            int[] s = srcData[b];

            int dstLineOffset = dstBandOffsets[b];
            int srcLineOffset = srcBandOffsets[b];

            for (int h = 0; h < dstHeight; h++) {
                int dstPixelOffset = dstLineOffset;
                int srcPixelOffset = srcLineOffset;

                dstLineOffset += dstLineStride;
                srcLineOffset += srcLineStride;

                for (int w = 0; w < dstWidth; w++) {
                    val = ((s[srcPixelOffset] >= inMin) ? (s[srcPixelOffset] - inMin) : 0);

                    d[dstPixelOffset] = ImageUtil.clampRoundInt((Math.sqrt((val / (inMax - inMin))) * outMax) + outMin);

                    dstPixelOffset += dstPixelStride;
                    srcPixelOffset += srcPixelStride;
                }
            }
        }
    }

    //    private void computeRectFloat(RasterAccessor src,
    //                                  RasterAccessor dst) {
    //        int dstWidth = dst.getWidth();
    //        int dstHeight = dst.getHeight();
    //        int dstBands = dst.getNumBands();
    //
    //        int dstLineStride = dst.getScanlineStride();
    //        int dstPixelStride = dst.getPixelStride();
    //        int[] dstBandOffsets = dst.getBandOffsets();
    //        float[][] dstData = dst.getFloatDataArrays();
    //
    //        int srcLineStride = src.getScanlineStride();
    //        int srcPixelStride = src.getPixelStride();
    //        int[] srcBandOffsets = src.getBandOffsets();
    //        float[][] srcData = src.getFloatDataArrays();
    //
    //        for (int b = 0; b < dstBands; b++) {
    //            double c = inputMin[b];
    //            float[] d = dstData[b];
    //            float[] s = srcData[b];
    //
    //            int dstLineOffset = dstBandOffsets[b];
    //            int srcLineOffset = srcBandOffsets[b];
    //
    //            for (int h = 0; h < dstHeight; h++) {
    //                int dstPixelOffset = dstLineOffset;
    //                int srcPixelOffset = srcLineOffset;
    //
    //                dstLineOffset += dstLineStride;
    //                srcLineOffset += srcLineStride;
    //
    //                for (int w = 0; w < dstWidth; w++) {
    //                    d[dstPixelOffset] = ImageUtil.clampFloat(s[srcPixelOffset] * c);
    //
    //                    dstPixelOffset += dstPixelStride;
    //                    srcPixelOffset += srcPixelStride;
    //                }
    //            }
    //        }
    //    }
    //
    //    private void computeRectDouble(RasterAccessor src,
    //                                   RasterAccessor dst) {
    //        int dstWidth = dst.getWidth();
    //        int dstHeight = dst.getHeight();
    //        int dstBands = dst.getNumBands();
    //
    //        int dstLineStride = dst.getScanlineStride();
    //        int dstPixelStride = dst.getPixelStride();
    //        int[] dstBandOffsets = dst.getBandOffsets();
    //        double[][] dstData = dst.getDoubleDataArrays();
    //
    //        int srcLineStride = src.getScanlineStride();
    //        int srcPixelStride = src.getPixelStride();
    //        int[] srcBandOffsets = src.getBandOffsets();
    //        double[][] srcData = src.getDoubleDataArrays();
    //
    //        for (int b = 0; b < dstBands; b++) {
    //            double c = inputMin[b];
    //            double[] d = dstData[b];
    //            double[] s = srcData[b];
    //
    //            int dstLineOffset = dstBandOffsets[b];
    //            int srcLineOffset = srcBandOffsets[b];
    //
    //            for (int h = 0; h < dstHeight; h++) {
    //                int dstPixelOffset = dstLineOffset;
    //                int srcPixelOffset = srcLineOffset;
    //
    //                dstLineOffset += dstLineStride;
    //                srcLineOffset += srcLineStride;
    //
    //                for (int w = 0; w < dstWidth; w++) {
    //                    d[dstPixelOffset] = s[srcPixelOffset] * c;
    //
    //                    dstPixelOffset += dstPixelStride;
    //                    srcPixelOffset += srcPixelStride;
    //                }
    //            }
    //        }
    //    }
}
